import os
import base64
from io import BytesIO

import tqdm
import pandas as pd
import torch
from facenet_pytorch import MTCNN
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib import patches
from matplotlib.ticker import NullLocator



def annotate(img, x1, y1, x2, y2):
    """Show ``img`` on a new matplotlib figure with one bounding box drawn on it.

    The figure is left open as pyplot's current figure so that the caller can
    save it (e.g. with ``plt.savefig``) and is responsible for closing it.

    Args:
        img: Image to display; passed straight to ``Axes.imshow``.
        x1: X coordinate of the first box corner (used as the rectangle anchor).
        y1: Y coordinate of the first box corner (used as the rectangle anchor).
        x2: X coordinate of the opposite box corner.
        y2: Y coordinate of the opposite box corner.
    """
    # plt.subplots() creates its own figure. Calling plt.figure() beforehand
    # only opened an additional, empty figure that was never used.
    _, ax = plt.subplots(1)
    ax.imshow(img)

    # Rectangle takes an anchor point plus width/height rather than two corners.
    box_w = x2 - x1
    box_h = y2 - y1
    bbox = patches.Rectangle(
        (x1, y1), box_w, box_h, linewidth=2, edgecolor="b", facecolor="none"
    )
    ax.add_patch(bbox)

    # Hide axis decorations and tick locators so that a tight-bbox save
    # contains only the image area.
    ax.axis("off")
    ax.xaxis.set_major_locator(NullLocator())
    ax.yaxis.set_major_locator(NullLocator())


def annotate_faces(
    fp_imgs: str,
    output_dir: str,
    prob_thresh: float = 0.95,
    max_n_ppl: int = 3,
    top_n_ranks: int = 15,
):
    """Detect faces in base64-encoded images and save one annotated JPEG per face.

    Reads a CSV that must provide the columns ``img_rank``, ``n_people``,
    ``img_id`` and ``img_b64``. Rows with ``img_rank < top_n_ranks`` are kept,
    sorted by ``n_people`` in descending order and de-duplicated on ``img_id``
    (the first row per id after sorting wins). Each remaining image is decoded,
    run through MTCNN, and for every retained detection a copy of the image
    with that single box drawn on it is written to
    ``<output_dir>/<img_id>-<i>.jpg``.

    Args:
        fp_imgs: Path of the CSV file to read.
        output_dir: Directory for the annotated images; created if missing.
        prob_thresh: Detections with a probability below this value are dropped.
        max_n_ppl: Maximum number of detections annotated per image.
        top_n_ranks: Exclusive upper bound applied to ``img_rank``.
    """
    # Load the table and reduce it to one row per image within the rank cutoff.
    table = pd.read_csv(fp_imgs)
    within_cutoff = table.img_rank < top_n_ranks
    uniq_imgs = (
        table.loc[within_cutoff, :]
        .sort_values(by="n_people", ascending=False)
        .drop_duplicates(subset="img_id")
        .reset_index(drop=True)
    )

    # Build the detector on the GPU when one is available, otherwise on the CPU.
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    mtcnn = MTCNN(keep_all=True, select_largest=False, device=device)

    os.makedirs(output_dir, exist_ok=True)

    records = zip(uniq_imgs.img_id, uniq_imgs.img_b64)
    for img_id, img_data in tqdm.tqdm(records, total=len(uniq_imgs)):
        raw_bytes = base64.b64decode(img_data)
        img = Image.open(BytesIO(raw_bytes)).convert("RGB")

        boxes, probs = mtcnn.detect(img)
        if boxes is None:
            # The detector returned no boxes for this image; nothing to draw.
            continue

        # Keep only sufficiently confident detections, capped at max_n_ppl.
        confident = boxes[probs >= prob_thresh]
        for i, (x1, y1, x2, y2) in enumerate(confident[:max_n_ppl]):
            out = os.path.join(output_dir, f"{img_id}-{i}.jpg")
            annotate(img, x1, y1, x2, y2)
            plt.savefig(out, bbox_inches="tight", pad_inches=0.0)
            # Release every open figure before moving on to the next face.
            plt.close("all")


if __name__ == "__main__":
    # annotate_faces() requires fp_imgs and output_dir; calling it with no
    # arguments always raised TypeError, so take them from the command line.
    import argparse

    parser = argparse.ArgumentParser(
        description="Draw face bounding boxes on base64-encoded images from a CSV."
    )
    parser.add_argument("fp_imgs", help="path of the input CSV file")
    parser.add_argument("output_dir", help="directory for the annotated images")
    # SUPPRESS leaves unspecified options out of the namespace, so the
    # defaults declared on annotate_faces() remain the single source of truth.
    parser.add_argument(
        "--prob-thresh",
        type=float,
        default=argparse.SUPPRESS,
        help="minimum detection probability to keep a face",
    )
    parser.add_argument(
        "--max-n-ppl",
        type=int,
        default=argparse.SUPPRESS,
        help="maximum number of faces annotated per image",
    )
    parser.add_argument(
        "--top-n-ranks",
        type=int,
        default=argparse.SUPPRESS,
        help="only process rows whose img_rank is below this value",
    )
    annotate_faces(**vars(parser.parse_args()))
